{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6fa9c631",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import pickle\n",
    "import os\n",
    "from PIL import Image\n",
    "from pathlib import Path\n",
    "from tqdm import tqdm\n",
    "import dnnlib, legacy\n",
    "import clip\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as T\n",
    "from tqdm import tqdm\n",
    "import scipy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b93b1f9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Generator:\n",
    "    def __init__(self, device, path):\n",
    "        self.name = 'generator'\n",
    "        self.model = self.load_model(device, path)\n",
    "        self.device = device\n",
    "        self.force_32 = False\n",
    "        \n",
    "    def load_model(self, device, path):\n",
    "        with dnnlib.util.open_url(path) as f:\n",
    "            network= legacy.load_network_pkl(f)\n",
    "            self.G_ema = network['G_ema'].to(device)\n",
    "            self.D = network['D'].to(device)\n",
    "#                 self.G = network['G'].to(device)\n",
    "            return self.G_ema\n",
    "        \n",
    "    def generate(self, z, c, fts, noise_mode='const', return_styles=True):\n",
    "        return self.model(z, c, fts=fts, noise_mode=noise_mode, return_styles=return_styles, force_fp32=self.force_32)\n",
    "    \n",
    "    def generate_from_style(self, style, noise_mode='const'):\n",
    "        ws = torch.randn(1, self.model.num_ws, 512)\n",
    "        return self.model.synthesis(ws, fts=None, styles=style, noise_mode=noise_mode, force_fp32=self.force_32)\n",
    "    \n",
    "    def tensor_to_img(self, tensor):\n",
    "        img = torch.clamp((tensor + 1.) * 127.5, 0., 255.)\n",
    "        img_list = img.permute(0, 2, 3, 1)\n",
    "        img_list = [img for img in img_list]\n",
    "        return Image.fromarray(torch.cat(img_list, dim=-2).detach().cpu().numpy().astype(np.uint8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "29873681",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Text-to-image demo: encode a sentence with CLIP, pass the normalized text\n",
    "# feature to the generator, and save the resulting image to disk.\n",
    "# Everything runs under no_grad because nothing here calls backward.\n",
    "with torch.no_grad():\n",
    "\n",
    "    device = 'cuda:0' # please use GPU, do not use CPU\n",
    "    path = './some_pre-trained_models.pkl'  # path to the pre-trained network pickle\n",
    "    generator = Generator(device=device, path=path)\n",
    "    # clip.load returns two values; only the model is kept here.\n",
    "    clip_model, _ = clip.load(\"ViT-B/32\", device=device)\n",
    "    clip_model = clip_model.eval()\n",
    "    \n",
    "    txt = 'sentence'  # input sentence describing the image to generate\n",
    "    tokenized_text = clip.tokenize([txt]).to(device)\n",
    "    txt_fts = clip_model.encode_text(tokenized_text)\n",
    "    # Scale the text feature to unit L2 norm along its last dimension.\n",
    "    txt_fts = txt_fts/txt_fts.norm(dim=-1, keepdim=True)\n",
    "    \n",
    "    # Random latent, drawn on the CPU and then moved to the device.\n",
    "    z = torch.randn((1, 512)).to(device)\n",
    "    c = torch.randn((1, 1)).to(device) # label is actually not used\n",
    "    # generate() returns a pair; only the first element (the image tensor) is used.\n",
    "    img, _ = generator.generate(z=z, c=c, fts=txt_fts)\n",
    "    to_show_img = generator.tensor_to_img(img)\n",
    "    to_show_img.save('./generated.jpg')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}